{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import torchvision\n",
    "import torch.nn as nn\n",
    "from datetime import datetime\n",
    "\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.autograd import Variable\n",
    "\n",
    "from LipReadDataTrain import ReadData\n",
    "from LipNet import LipNet, LipSeqLoss"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "===============================================\n",
    "           1. Data\n",
    "==============================================="
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paths to the training image data and the label file, resolved against the\n",
    "# current working directory.\n",
    "train_image_file = os.path.join(os.path.abspath('.'), \"data/lip_train\")\n",
    "train_label_file = os.path.join(os.path.abspath('.'), \"data/lip_train.txt\")\n",
    "training_dataset = ReadData(train_image_file, train_label_file, seq_max_lens=24)\n",
    "training_data_loader = DataLoader(training_dataset, batch_size=20, shuffle=True, num_workers=12, drop_last=True)\n",
    "\n",
    "# Use the first GPU when CUDA is available; otherwise fall back to the CPU so\n",
    "# this cell still runs on machines without a GPU.\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Model, optimizer and loss module; the model and loss are moved to `device`.\n",
    "model = LipNet().to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "loss_fc = LipSeqLoss().to(device)"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "===============================================\n",
    "            2. Training\n",
    "==============================================="
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n"
     ]
    }
   ],
   "source": [
    "# Create the checkpoint directory up front so the first torch.save() below\n",
    "# does not fail after several epochs of training.\n",
    "os.makedirs(\"./weight\", exist_ok=True)\n",
    "\n",
    "for epoch in range(1,1000):\n",
    "    print(epoch)\n",
    "    model.train()\n",
    "    loss = None  # stays None if the loader yields no batches this epoch\n",
    "    for sample_batched in training_data_loader:\n",
    "        # Tensors can be moved to the device directly; the deprecated\n",
    "        # Variable wrapper is no longer needed.\n",
    "        input_data = sample_batched['volume'].to(device)\n",
    "        labels = sample_batched['label'].to(device)\n",
    "        length = sample_batched['length'].to(device)\n",
    "\n",
    "        # One optimization step: clear old gradients, forward, backward, update.\n",
    "        optimizer.zero_grad()\n",
    "        result = model(input_data)\n",
    "        loss = loss_fc(result, length, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    # Every 5 epochs: report progress and write a checkpoint.\n",
    "    if epoch % 5 == 0:\n",
    "        current_time = datetime.now()\n",
    "        print(\"current time:\", current_time)\n",
    "        print(\"number of epoch:\", epoch)\n",
    "        # Loss of the last batch of this epoch, printed as a plain Python\n",
    "        # number instead of a tensor repr.\n",
    "        print(\"current loss:\", loss.item() if loss is not None else None)\n",
    "\n",
    "        # save model\n",
    "        torch.save(model.state_dict(), \"./weight/demo_net_epoch_{}.pt\".format(epoch))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
